import time
import serial
import serial.tools.list_ports
from scipy import signal
from pyampd.ampd import find_peaks
import tkinter as tk
from tkinter import messagebox

# Create the Tk root and hide it immediately so the tkinter.messagebox error
# dialogs shown by Device can appear without an empty main window.
root = tk.Tk()
root.withdraw()


class Device:
    """Read pulse-oximeter samples from a serial device and derive HR / SpO2.

    On construction the first port reported by
    ``serial.tools.list_ports.comports()`` is opened. Each call to
    :meth:`get_data` reads one line. Lines containing ``[DATA]`` are parsed
    into ``red``, ``ir`` and ``Time`` values. These values are pushed into
    sliding windows and used to update rolling heart-rate (``HR``) and SpO2
    estimates.

    Attributes of interest to callers:
        valid: True once the serial port has been opened successfully.
        data: dict with the latest "red", "ir", "HR", "SPO2", "Time" and
            "Fin" values. The same dict object is returned by every call to
            ``get_data``.
        ir_list_data_filtered / red_list_data_filtered: low-pass +
            high-pass filtered, linearly detrended copies of the raw windows.
        red_list_show_data: circular buffer of raw red samples, written at
            ``red_list_show_index``.
    """

    # Raw IR reading above which data["Fin"] is set to 1
    # (presumably "finger on sensor" -- confirm against the sensor setup).
    FINGER_IR_THRESHOLD = 150000

    def __init__(self, sample_rate=100, avg_rate=10, brate=115200):
        """Initialise all buffers, then open the first available serial port.

        Args:
            sample_rate: used as the length, in samples, of every sliding
                window. ``signal.filtfilt`` with these 8th-order filters pads
                by 27 samples, so windows must be longer than that.
            avg_rate: number of recent HR / SpO2 estimates kept for averaging.
            brate: serial baud rate.

        If no port is found, or the port cannot be opened, an error dialog is
        shown and ``valid`` stays False. ``get_data`` then returns the
        all-zero ``data`` dict without touching the port.
        """
        self.valid = False
        self.brate = brate
        self.view_mode = "Dynamic"
        self.filter_mode = "Filter OFF"
        # (sic) historical attribute name, kept so existing readers still work.
        self.sampe_rate = sample_rate
        self.avg_rate = avg_rate

        # 8th-order Butterworth filters; cut-offs are normalised to Nyquist.
        self.h_b, self.h_a = signal.butter(8, 0.1, "highpass")
        self.l_b, self.l_a = signal.butter(8, 0.3, "lowpass")

        # State is created before the port is probed so that an invalid
        # Device still has a usable `data` dict and empty buffers.
        self.read_data = ""
        self.data_valid = 0
        self.k = 0
        self.data_list_time = []
        self.hr_time_start_end = [0, 0]
        self.range_list_time = list(range(self.sampe_rate))
        self.ir_list_data = []
        self.red_list_show_data = list(range(self.sampe_rate))
        self.red_list_show_index = 0
        self.red_list_data = []
        self.bpm_list_data = []
        self.SPO2_list_data = []
        self.ir_list_data_filtered = []
        self.red_list_data_filtered = []
        self.timer = time.time()
        self.data_update_flag = False
        self.data = {
            "red": 0,
            "ir": 0,
            "HR": 0,
            "SPO2": 0,
            "Time": 0,
            "Fin": 0,
        }

        self.port = None
        self.ser = None
        self.port_list = [p.device for p in serial.tools.list_ports.comports()]
        if not self.port_list:
            messagebox.showerror(title="设备错误", message="未发现可用设备！")
            return
        self.port = self.port_list[0]
        try:
            self.ser = serial.Serial(self.port, brate)
        except serial.SerialException as err:
            messagebox.showerror(
                title="设备错误", message=f"无法打开串口 {self.port}: {err}"
            )
            return
        self.valid = True

    def _culculate_spo2(self):
        """Estimate SpO2 (%) from the raw red and IR windows.

        AC is taken as max - min of each window and DC as its minimum. The
        ratio R = (red_ac / red_dc) / (ir_ac / ir_dc) is fed into an empirical
        quadratic calibration curve.

        Returns:
            The estimate. Returns 0 when the windows are not yet valid, when a
            window is empty or flat (which would divide by zero), or when the
            result falls outside 0..100.
        """
        if not self.data_valid or not self.ir_list_data or not self.red_list_data:
            return 0
        ir_dc = min(self.ir_list_data)
        red_dc = min(self.red_list_data)
        ir_ac = max(self.ir_list_data) - ir_dc
        red_ac = max(self.red_list_data) - red_dc
        denominator = red_dc * ir_ac
        if denominator == 0:
            return 0
        R2 = (red_ac * ir_dc) / denominator
        SPO2 = -45.060 * R2 * R2 + 30.354 * R2 + 94.845
        if SPO2 > 100 or SPO2 < 0:
            return 0
        return SPO2

    def _culculate_HR(self):
        """Estimate heart rate from peaks in the raw red window.

        Peak indices come from pyampd's ``find_peaks``. The interval between
        the third-last and second-last peaks is looked up in
        ``data_list_time`` and converted with ``60000 / interval``. That is
        beats per minute if ``Time`` is in milliseconds (TODO confirm the
        device's time unit).

        Side effects: stores the last peak index in ``self.k`` and the
        (third-last, second-last) pair in ``self.hr_time_start_end``.

        Returns:
            The estimate, or 0 in any of these cases:
            - there is not enough data;
            - the detector fails;
            - fewer than four peaks are found;
            - the interval is not positive.
        """
        if not self.data_valid or len(self.red_list_data) < self.sampe_rate:
            return 0
        try:
            peaks = find_peaks(self.red_list_data)
        except Exception:
            # Best effort: a detector failure just means "no estimate now".
            return 0
        if len(peaks) == 0:
            return 0
        self.k = peaks[-1]
        if len(peaks) < 3:
            return 0
        self.hr_time_start_end = [peaks[-3], peaks[-2]]
        if len(peaks) <= 3:
            return 0
        start, end = self.hr_time_start_end
        if end >= len(self.data_list_time):
            return 0
        interval = self.data_list_time[end] - self.data_list_time[start]
        if interval <= 0:
            # Equal or decreasing timestamps would divide by zero or give a
            # negative rate.
            return 0
        return 60000 / interval

    def _read_line(self):
        """Return the next serial line that decodes as UTF-8, skipping bad ones."""
        while True:
            raw = self.ser.readline()
            try:
                return raw.decode("utf-8")
            except UnicodeDecodeError:
                continue

    @staticmethod
    def _parse_data_line(line):
        """Parse a ``[DATA]`` line into ``(red, ir, time)`` ints.

        After the first ``]`` the line must hold at least three
        comma-separated ``key=value`` fields. Values are taken by position
        (red, ir, time); keys are not checked and extra fields are ignored.
        Returns None if the line is malformed.
        """
        try:
            fields = line.split("]")[1].strip().split(",")
            red, ir, stamp = (int(field.split("=")[1]) for field in fields[:3])
        except (IndexError, ValueError):
            return None
        return red, ir, stamp

    def _update_average(self, history, key):
        """Trim ``history`` and, when a refresh is due, publish its mean.

        Once more than ``avg_rate`` estimates are stored, the oldest one is
        dropped. If ``data_update_flag`` is set, the mean of the remaining
        estimates is written to ``data[key]``.
        """
        if len(history) > self.avg_rate:
            history.pop(0)
            # NOTE(review): 0 entries ("no estimate") are included in the mean
            # and pull it down -- confirm this is intended.
            if self.data_update_flag and history:
                self.data[key] = sum(history) / len(history)

    def _filter_window(self, samples):
        """Return ``samples`` low-pass then high-pass filtered and linearly detrended.

        Both filters are applied with ``signal.filtfilt`` (zero-phase).
        """
        data = signal.filtfilt(self.l_b, self.l_a, samples, axis=0)
        data = signal.filtfilt(self.h_b, self.h_a, data, axis=0)
        return signal.detrend(
            data,
            axis=0,
            type="linear",
            bp=0,
            overwrite_data=False,
        )

    def get_data(self):
        """Read one line from the device and update the running estimates.

        Returns:
            The ``data`` dict. It is only modified when the line read is a
            well-formed ``[DATA]`` record. Other lines, and malformed records,
            leave it unchanged. Nothing is read when the device is not valid.
        """
        if not self.valid:
            return self.data

        # Let the published HR / SpO2 averages refresh about once a second;
        # the flag is cleared after the next [DATA] record is processed.
        if time.time() - self.timer > 1:
            self.data_update_flag = True
            self.timer = time.time()

        if (
            len(self.ir_list_data_filtered) >= self.sampe_rate
            and len(self.red_list_data_filtered) >= self.sampe_rate
            and len(self.red_list_show_data) >= self.sampe_rate
        ):
            self.data_valid = 1

        self.read_data = self._read_line()
        if "[DATA]" not in self.read_data:
            return self.data
        sample = self._parse_data_line(self.read_data)
        if sample is None:
            # Malformed record: skip it rather than crash the reader loop.
            return self.data
        self.data["red"], self.data["ir"], self.data["Time"] = sample

        self._update_average(self.bpm_list_data, "HR")
        self.bpm_list_data.append(self._culculate_HR())

        self._update_average(self.SPO2_list_data, "SPO2")
        self.SPO2_list_data.append(self._culculate_spo2())

        # Filtered windows are recomputed from the latest `sampe_rate` raw
        # samples, before the new sample is appended.
        if len(self.ir_list_data) > self.sampe_rate:
            self.ir_list_data.pop(0)
            self.ir_list_data_filtered = self._filter_window(self.ir_list_data)
        self.ir_list_data.append(self.data["ir"])

        if len(self.red_list_data) > self.sampe_rate:
            self.red_list_data.pop(0)
            self.red_list_data_filtered = self._filter_window(self.red_list_data)
        self.red_list_data.append(self.data["red"])

        if self.red_list_show_index < self.sampe_rate:
            self.red_list_show_data[self.red_list_show_index] = self.data["red"]
            self.red_list_show_index = (self.red_list_show_index + 1) % self.sampe_rate

        if len(self.data_list_time) > self.sampe_rate:
            self.data_list_time.pop(0)
        self.data_list_time.append(self.data["Time"])

        self.data["Fin"] = 1 if self.data["ir"] > self.FINGER_IR_THRESHOLD else 0

        self.data_update_flag = False
        return self.data


if __name__ == "__main__":
    # Stream readings to stdout until interrupted with Ctrl+C.
    device = Device()
    if not device.valid:
        # No usable serial port; Device has already shown an error dialog.
        raise SystemExit(1)
    try:
        while True:
            print(device.get_data())
    except KeyboardInterrupt:
        pass
    finally:
        # Release the serial port on exit.
        device.ser.close()
